# --------------------------------------------------------------------

class SolverType:
    """
    TAO Solver Types

    Names of the built-in TAO solver implementations.  Each value is the
    TAO type-name constant ``TAO_*`` wrapped by the ``S_`` helper
    (presumably a C-string -> ``str`` conversion; defined elsewhere).
    Exposed as ``Solver.Type`` and intended as the argument of
    ``Solver.setType()``.
    """
    # unconstrained / Newton-type and quasi-Newton methods
    LMVM     = S_(TAO_LMVM)
    NLS      = S_(TAO_NLS)
    NTR      = S_(TAO_NTR)
    NTL      = S_(TAO_NTL)
    CG       = S_(TAO_CG)
    # bound-constrained methods
    TRON     = S_(TAO_TRON)
    BLMVM    = S_(TAO_BLMVM)
    BQPIP    = S_(TAO_BQPIP)
    GPCG     = S_(TAO_GPCG)
    # derivative-free methods
    NM       = S_(TAO_NM)
    POUNDERS = S_(TAO_POUNDERS)
    # PDE-constrained (state/design) method
    LCL      = S_(TAO_LCL)
    # complementarity solvers
    SSILS    = S_(TAO_SSILS)
    SSFLS    = S_(TAO_SSFLS)
    ASILS    = S_(TAO_ASILS)
    ASFLS    = S_(TAO_ASFLS)
    # finite-difference gradient/Hessian tester
    FD_TEST  = S_(TAO_FD_TEST)

class SolverConvergedReason:
    """
    TAO Solver Termination Reasons

    Integer codes returned by ``Solver.getConvergedReason()``.  Zero
    means the solver is still iterating, positive values mean it
    converged, and negative values mean it diverged (this sign convention
    is what the ``Solver.iterating/converged/diverged`` properties test).
    """
    # iterating
    CONTINUE_ITERATING    = TAO_CONTINUE_ITERATING    # iterating
    CONVERGED_ITERATING   = TAO_CONTINUE_ITERATING    # alias: same value as CONTINUE_ITERATING
    ITERATING             = TAO_CONTINUE_ITERATING    # alias: same value as CONTINUE_ITERATING
    # converged
    CONVERGED_FATOL       = TAO_CONVERGED_FATOL       # f(X)-f(X*) <= fatol
    CONVERGED_FRTOL       = TAO_CONVERGED_FRTOL       # |f(X)-f(X*)|/|f(X)| < frtol
    CONVERGED_GATOL       = TAO_CONVERGED_GATOL       # ||g(X)|| < gatol
    CONVERGED_GRTOL       = TAO_CONVERGED_GRTOL       # ||g(X)||/|f(X)| < grtol
    CONVERGED_GTTOL       = TAO_CONVERGED_GTTOL       # ||g(X)||/||g(X0)|| < gttol
    CONVERGED_STEPTOL     = TAO_CONVERGED_STEPTOL     # small step size
    CONVERGED_MINF        = TAO_CONVERGED_MINF        # f(X) < F_min
    CONVERGED_USER        = TAO_CONVERGED_USER        # user defined
    # diverged
    DIVERGED_MAXITS       = TAO_DIVERGED_MAXITS       # iterations > maximum iterations
    DIVERGED_NAN          = TAO_DIVERGED_NAN          # numerical problems (NaN/Inf)
    DIVERGED_MAXFCN       = TAO_DIVERGED_MAXFCN       # function evaluations > maximum
    DIVERGED_LS_FAILURE   = TAO_DIVERGED_LS_FAILURE   # line search failure
    DIVERGED_TR_REDUCTION = TAO_DIVERGED_TR_REDUCTION # trust-region reduction failure
    DIVERGED_USER         = TAO_DIVERGED_USER         # user defined

# --------------------------------------------------------------------

cdef class Solver(Object):

    """
    TAO Solver

    Python wrapper around a TAO ``TaoSolver`` handle.  Python callbacks
    registered through the ``set*`` methods are stored as attributes of
    this object (``set_attr``) and invoked from C through trampoline
    functions (``Tao_Objective``, ``Tao_Gradient``, ...) defined elsewhere.
    """

    Type   = SolverType
    Reason = SolverConvergedReason

    def __cinit__(self):
        # Point the base-class object slot at our handle so generic
        # Object methods operate on this solver.
        self.obj = <PetscObject*> &self.slv
        self.slv = NULL

    def __dealloc__(self):
        # Nothing solver-specific to release here.
        # NOTE(review): presumably the Object base class drops the reference.
        pass

    def view(self, Viewer viewer=None):
        """
        Print/view the solver with ``TaoView``.

        :param viewer: PETSc viewer; ``None`` passes a NULL viewer (TAO's
                       default output).
        """
        cdef PetscViewer vwr = NULL
        if viewer is not None: vwr = viewer.vwr
        CHKERR( TaoView(self.slv, vwr) )

    def destroy(self):
        """
        Destroy the underlying TAO solver (``TaoDestroy``).

        :returns: ``self``.
        """
        CHKERR( TaoDestroy(&self.slv) )
        return self

    def create(self, comm=None):
        """
        Create a new TAO solver.

        :param comm: MPI communicator; defaults to ``GetCommDefault()``.
        :returns: ``self``, now holding the new solver.  Any previously
                  held solver is cleared (``PetscCLEAR``) first.
        """
        cdef MPI_Comm ccomm = def_Comm(comm, GetCommDefault())
        cdef TaoSolver newslv = NULL
        CHKERR( TaoCreate(ccomm, &newslv) )
        PetscCLEAR(self.obj); self.slv = newslv
        return self

    def setType(self, solver_type):
        """
        Select the solver algorithm (see ``Solver.Type``).
        """
        cdef TaoSolverType ctype = NULL
        solver_type = str2bytes(solver_type, &ctype)
        CHKERR( TaoSetType(self.slv, ctype) )

    def getType(self):
        """
        Return the name of the solver algorithm as a ``str``.
        """
        cdef TaoSolverType ctype = NULL
        CHKERR( TaoGetType(self.slv, &ctype) )
        return bytes2str(ctype)

    def setOptionsPrefix(self, prefix):
        """
        Set the prefix used for this solver's options-database entries.
        """
        cdef const_char *cprefix = NULL
        prefix = str2bytes(prefix, &cprefix)
        CHKERR( TaoSetOptionsPrefix(self.slv, cprefix) )

    def getOptionsPrefix(self):
        """
        Return the options-database prefix of this solver.
        """
        cdef const_char *prefix = NULL
        CHKERR( TaoGetOptionsPrefix(self.slv, &prefix) )
        return bytes2str(prefix)

    def setFromOptions(self):
        """
        Configure the solver from the options database.
        """
        CHKERR( TaoSetFromOptions(self.slv) )

    def setUp(self):
        """
        Set up the internal data structures of the solver (``TaoSetUp``).
        """
        CHKERR( TaoSetUp(self.slv) )

    #

    # --------------

    def setAppCtx(self, appctx):
        """
        Attach an arbitrary Python application context to the solver.
        """
        self.set_attr("__appctx__", appctx)

    def getAppCtx(self):
        """
        Return the application context set with ``setAppCtx``.
        """
        return self.get_attr("__appctx__")

    def setInitial(self, Vec x not None):
        """
        Set the initial guess vector (also used as the solution vector).
        """
        CHKERR( TaoSetInitialVector(self.slv, x.vec) )

    def setObjective(self, objective, args=None, kargs=None):
        """
        Register a Python callback that evaluates the objective function.

        :param objective: callable; extra ``args``/``kargs`` are stored
                          with it and forwarded by the C trampoline.
        """
        CHKERR( TaoSetObjectiveRoutine(self.slv, Tao_Objective, NULL) )
        if args is None: args = ()
        if kargs is None: kargs = {}
        self.set_attr("__objective__", (objective, args, kargs))

    def setSeparableObjective(self, separable, Vec O=None,
                              args=None, kargs=None):
        """
        Register a Python callback that evaluates a separable objective.

        :param separable: callable evaluated into the vector ``O``.
        :param O: output vector (may be ``None``, passed as NULL).
        """
        cdef PetscVec Ovec = NULL
        if O is not None: Ovec = O.vec
        CHKERR( TaoSetSeparableObjectiveRoutine(self.slv, Ovec,
                                                Tao_SeparableObjective, NULL) )
        # Keep a reference to the output vector attached to the solver.
        CHKERR( PetscObjectCompose(self.obj[0], "@sepobjvec", <PetscObject>Ovec) )
        if args is None: args = ()
        if kargs is None: kargs = {}
        self.set_attr("__separable__", (separable, args, kargs))

    def setGradient(self, gradient, args=None, kargs=None):
        """
        Register a Python callback that evaluates the gradient.
        """
        CHKERR( TaoSetGradientRoutine(self.slv, Tao_Gradient, NULL) )
        if args is None: args = ()
        if kargs is None: kargs = {}
        self.set_attr("__gradient__", (gradient, args, kargs))

    def setObjectiveGradient(self, objgrad, args=None, kargs=None):
        """
        Register a Python callback that evaluates objective and gradient
        together.
        """
        CHKERR( TaoSetObjectiveAndGradientRoutine(self.slv, Tao_ObjGrad, NULL) )
        if args is None: args = ()
        if kargs is None: kargs = {}
        self.set_attr("__objgrad__", (objgrad, args, kargs))

    def setVariableBounds(self, varbounds, args=None, kargs=None):
        """
        Set bounds on the optimization variables.

        Three calling forms are accepted:

        - ``setVariableBounds((xl, xu))``: a list/tuple of two Vecs.
        - ``setVariableBounds(xl, xu)``: lower Vec first, upper Vec passed
          in the ``args`` position.
        - ``setVariableBounds(callback, args, kargs)``: a Python callback
          that computes the bounds on demand.

        In the two Vec forms a ``None`` bound is passed to TAO as NULL
        (no bound on that side) instead of being dereferenced.
        """
        cdef Vec xl=None, xu=None
        cdef PetscVec lvec=NULL, uvec=NULL
        if isinstance(varbounds, (list, tuple)):
            ol, ou = varbounds
            xl = <Vec?> ol; xu = <Vec?> ou
        elif isinstance(varbounds, Vec):
            # in this form the upper bound arrives as the second argument
            xl = <Vec?> varbounds; xu = <Vec?> args
        else:
            CHKERR( TaoSetVariableBoundsRoutine(self.slv, Tao_VarBounds, NULL) )
            if args is None: args = ()
            if kargs is None: kargs = {}
            self.set_attr("__varbounds__", (varbounds, args, kargs))
            return
        # <Vec?> lets None through; never touch .vec on None.
        if xl is not None: lvec = xl.vec
        if xu is not None: uvec = xu.vec
        CHKERR( TaoSetVariableBounds(self.slv, lvec, uvec) )

    def setConstraints(self, constraints, Vec C=None,
                       args=None, kargs=None):
        """
        Register a Python callback that evaluates the constraints into
        ``C`` (``None`` is passed as NULL).
        """
        cdef PetscVec Cvec=NULL
        if C is not None: Cvec = C.vec
        CHKERR( TaoSetConstraintsRoutine(self.slv, Cvec, Tao_Constraints, NULL) )
        if args is None: args = ()
        if kargs is None: kargs = {}
        self.set_attr("__constraints__", (constraints, args, kargs))

    def setHessian(self, hessian, Mat H=None, Mat P=None,
                   args=None, kargs=None):
        """
        Register a Python callback that evaluates the Hessian.

        :param H: Hessian matrix (``None`` is passed as NULL).
        :param P: preconditioner matrix; defaults to ``H``.
        """
        cdef PetscMat Hmat=NULL
        if H is not None: Hmat = H.mat
        cdef PetscMat Pmat = Hmat
        if P is not None: Pmat = P.mat
        CHKERR( TaoSetHessianRoutine(self.slv, Hmat, Pmat, Tao_Hessian, NULL) )
        if args is None: args = ()
        if kargs is None: kargs = {}
        self.set_attr("__hessian__", (hessian, args, kargs))

    def setJacobian(self, jacobian, Mat J=None, Mat P=None,
                    args=None, kargs=None):
        """
        Register a Python callback that evaluates the Jacobian.

        :param J: Jacobian matrix (``None`` is passed as NULL).
        :param P: preconditioner matrix; defaults to ``J``.
        """
        cdef PetscMat Jmat=NULL
        if J is not None: Jmat = J.mat
        cdef PetscMat Pmat = Jmat
        if P is not None: Pmat = P.mat
        CHKERR( TaoSetJacobianRoutine(self.slv, Jmat, Pmat, Tao_Jacobian, NULL) )
        if args is None: args = ()
        if kargs is None: kargs = {}
        self.set_attr("__jacobian__", (jacobian, args, kargs))

    #

    def setStateDesignIS(self, IS state, IS design):
        """
        Set the index sets splitting variables into state and design
        parts (``None`` is passed as NULL).
        """
        cdef PetscIS s_is = NULL, d_is = NULL
        if state  is not None: s_is = state.iset
        if design is not None: d_is = design.iset
        CHKERR( TaoSetStateDesignIS(self.slv, s_is, d_is) )

    def setJacobianState(self, jacobian_state, Mat J=None, Mat P=None, Mat I=None,
                         args=None, kargs=None):
        """
        Register a Python callback for the Jacobian with respect to the
        state variables.

        :param J: Jacobian matrix (``None`` is passed as NULL).
        :param P: preconditioner matrix; defaults to ``J``.
        :param I: third matrix passed through to
                  ``TaoSetJacobianStateRoutine`` (``None`` -> NULL).
        """
        cdef PetscMat Jmat=NULL
        if J is not None: Jmat = J.mat
        cdef PetscMat Pmat = Jmat
        if P is not None: Pmat = P.mat
        cdef PetscMat Imat = NULL
        if I is not None: Imat = I.mat
        CHKERR( TaoSetJacobianStateRoutine(self.slv, Jmat, Pmat, Imat,
                                           Tao_JacobianState, NULL) )
        if args is None: args = ()
        if kargs is None: kargs = {}
        self.set_attr("__jacobian_state__", (jacobian_state, args, kargs))

    def setJacobianDesign(self, jacobian_design, Mat J=None,
                          args=None, kargs=None):
        """
        Register a Python callback for the Jacobian with respect to the
        design variables (``J`` of ``None`` is passed as NULL).
        """
        cdef PetscMat Jmat=NULL
        if J is not None: Jmat = J.mat
        CHKERR( TaoSetJacobianDesignRoutine(self.slv, Jmat,
                                            Tao_JacobianDesign, NULL) )
        if args is None: args = ()
        if kargs is None: kargs = {}
        self.set_attr("__jacobian_design__", (jacobian_design, args, kargs))

    # --------------

    def computeObjective(self, Vec x not None):
        """
        Evaluate and return the objective value at ``x``.
        """
        cdef PetscReal f = 0
        CHKERR( TaoComputeObjective(self.slv, x.vec, &f) )
        return f

    def computeSeparableObjective(self, Vec x not None, Vec f not None):
        """
        Evaluate the separable objective at ``x`` into ``f``.
        """
        CHKERR( TaoComputeSeparableObjective(self.slv, x.vec, f.vec) )

    def computeGradient(self, Vec x not None, Vec g not None):
        """
        Evaluate the gradient at ``x`` into ``g``.
        """
        CHKERR( TaoComputeGradient(self.slv, x.vec, g.vec) )

    def computeObjectiveGradient(self, Vec x not None, Vec g not None):
        """
        Evaluate the gradient at ``x`` into ``g`` and return the objective.
        """
        cdef PetscReal f = 0
        CHKERR( TaoComputeObjectiveAndGradient(self.slv, x.vec, &f, g.vec) )
        return f

    def computeDualVariables(self, Vec xl not None, Vec xu not None):
        """
        Compute the dual variables of the bound constraints into
        ``xl``/``xu``.
        """
        CHKERR( TaoComputeDualVariables(self.slv, xl.vec, xu.vec) )

    def computeVariableBounds(self, Vec xl not None, Vec xu not None):
        """
        Compute the variable bounds and copy them into ``xl``/``xu``.

        A side with no bound set in TAO is filled with
        ``TAO_NINFINITY`` (lower) or ``TAO_INFINITY`` (upper).  Output
        wrappers holding a NULL vector are skipped.
        """
        CHKERR( TaoComputeVariableBounds(self.slv) )
        cdef PetscVec Lvec = NULL, Uvec = NULL
        CHKERR( TaoGetVariableBounds(self.slv, &Lvec, &Uvec) )
        if xl.vec != NULL:
            if Lvec != NULL:
                CHKERR( VecCopy(Lvec, xl.vec) )
            else:
                CHKERR( VecSet(xl.vec, <PetscScalar>TAO_NINFINITY) )
        if xu.vec != NULL:
            if Uvec != NULL:
                CHKERR( VecCopy(Uvec, xu.vec) )
            else:
                CHKERR( VecSet(xu.vec, <PetscScalar>TAO_INFINITY) )

    def computeConstraints(self, Vec x not None, Vec c not None):
        """
        Evaluate the constraints at ``x`` into ``c``.
        """
        CHKERR( TaoComputeConstraints(self.slv, x.vec, c.vec) )

    def computeHessian(self, Vec x not None, Mat H not None, Mat P=None):
        """
        Evaluate the Hessian at ``x`` into ``H`` (and ``P``, which
        defaults to ``H``).

        :returns: the ``MatStructure`` flag reported by TAO.
        """
        cdef PetscMat *hmat = &H.mat, *pmat = &H.mat
        if P is not None: pmat = &P.mat
        cdef PetscMatStructure flag = MAT_DIFFERENT_NONZERO_PATTERN
        CHKERR( TaoComputeHessian(self.slv, x.vec, hmat, pmat, &flag) )
        return flag

    def computeJacobian(self, Vec x not None, Mat J not None, Mat P=None):
        """
        Evaluate the Jacobian at ``x`` into ``J`` (and ``P``, which
        defaults to ``J``).

        :returns: the ``MatStructure`` flag reported by TAO.
        """
        cdef PetscMat *jmat = &J.mat, *pmat = &J.mat
        if P is not None: pmat = &P.mat
        cdef PetscMatStructure flag = MAT_DIFFERENT_NONZERO_PATTERN
        CHKERR( TaoComputeJacobian(self.slv, x.vec, jmat, pmat, &flag) )
        return flag

    # --------------

    #

    def setTolerances(self,
                      fatol=None, frtol=None,
                      gatol=None, grtol=None, gttol=None):
        """
        Set objective and gradient convergence tolerances.

        Any tolerance left as ``None`` is passed as ``PETSC_DEFAULT``.
        """
        cdef PetscReal _fatol=PETSC_DEFAULT, _frtol=PETSC_DEFAULT
        if fatol is not None: _fatol = fatol
        if frtol is not None: _frtol = frtol
        cdef PetscReal _gatol=PETSC_DEFAULT, _grtol=PETSC_DEFAULT, _gttol=PETSC_DEFAULT
        if gatol is not None: _gatol = gatol
        if grtol is not None: _grtol = grtol
        if gttol is not None: _gttol = gttol
        CHKERR( TaoSetTolerances(self.slv, _fatol, _frtol, _gatol, _grtol, _gttol) )

    def getTolerances(self):
        """
        Return ``(fatol, frtol, gatol, grtol, gttol)``.
        """
        cdef PetscReal _fatol=PETSC_DEFAULT, _frtol=PETSC_DEFAULT
        cdef PetscReal _gatol=PETSC_DEFAULT, _grtol=PETSC_DEFAULT, _gttol=PETSC_DEFAULT
        CHKERR( TaoGetTolerances(self.slv, &_fatol, &_frtol, &_gatol, &_grtol, &_gttol) )
        return (_fatol, _frtol, _gatol, _grtol, _gttol)

    def setObjectiveTolerances(self, fatol=None, frtol=None):
        """
        Set only the objective tolerances; the gradient tolerances are
        passed as ``PETSC_DEFAULT``.
        """
        cdef PetscReal _fatol=PETSC_DEFAULT, _frtol=PETSC_DEFAULT
        cdef PetscReal _gatol=PETSC_DEFAULT, _grtol=PETSC_DEFAULT, _gttol=PETSC_DEFAULT
        if fatol is not None: _fatol = fatol
        if frtol is not None: _frtol = frtol
        CHKERR( TaoSetTolerances(self.slv, _fatol, _frtol, _gatol, _grtol, _gttol) )

    def getObjectiveTolerances(self):
        """
        Return ``(fatol, frtol)``.
        """
        cdef PetscReal _fatol=PETSC_DEFAULT, _frtol=PETSC_DEFAULT
        cdef PetscReal _gatol=PETSC_DEFAULT, _grtol=PETSC_DEFAULT, _gttol=PETSC_DEFAULT
        CHKERR( TaoGetTolerances(self.slv, &_fatol, &_frtol, &_gatol, &_grtol, &_gttol) )
        return (_fatol, _frtol)

    setFunctionTolerances = setObjectiveTolerances
    getFunctionTolerances = getObjectiveTolerances

    def setGradientTolerances(self, gatol=None, grtol=None, gttol=None):
        """
        Set only the gradient tolerances; the objective tolerances are
        passed as ``PETSC_DEFAULT``.
        """
        cdef PetscReal _fatol=PETSC_DEFAULT, _frtol=PETSC_DEFAULT
        cdef PetscReal _gatol=PETSC_DEFAULT, _grtol=PETSC_DEFAULT, _gttol=PETSC_DEFAULT
        if gatol is not None: _gatol = gatol
        if grtol is not None: _grtol = grtol
        if gttol is not None: _gttol = gttol
        CHKERR( TaoSetTolerances(self.slv, _fatol, _frtol, _gatol, _grtol, _gttol) )

    def getGradientTolerances(self):
        """
        Return ``(gatol, grtol, gttol)``.
        """
        cdef PetscReal _fatol=PETSC_DEFAULT, _frtol=PETSC_DEFAULT
        cdef PetscReal _gatol=PETSC_DEFAULT, _grtol=PETSC_DEFAULT, _gttol=PETSC_DEFAULT
        CHKERR( TaoGetTolerances(self.slv, &_fatol, &_frtol, &_gatol, &_grtol, &_gttol) )
        return (_gatol, _grtol, _gttol)

    def setConstraintTolerances(self, catol=None, crtol=None):
        """
        Set the constraint tolerances; ``None`` -> ``PETSC_DEFAULT``.
        """
        cdef PetscReal _catol=PETSC_DEFAULT, _crtol=PETSC_DEFAULT
        if catol is not None: _catol = catol
        if crtol is not None: _crtol = crtol
        CHKERR( TaoSetConstraintTolerances(self.slv, _catol, _crtol) )

    def getConstraintTolerances(self):
        """
        Return ``(catol, crtol)``.
        """
        cdef PetscReal _catol=PETSC_DEFAULT, _crtol=PETSC_DEFAULT
        CHKERR( TaoGetConstraintTolerances(self.slv, &_catol, &_crtol) )
        return (_catol, _crtol)

    def setConvergenceTest(self, converged, args=None, kargs=None):
        """
        Install a Python convergence test.

        :param converged: callable, or ``None`` to restore
                          ``TaoDefaultConvergenceTest``.
        """
        if converged is None:
            CHKERR( TaoSetConvergenceTest(self.slv, TaoDefaultConvergenceTest, NULL) )
            self.set_attr('__converged__', None)
        else:
            if args is None: args = ()
            if kargs is None: kargs = {}
            self.set_attr('__converged__', (converged, args, kargs))
            CHKERR( TaoSetConvergenceTest(self.slv, Tao_Converged, NULL) )

    def getConvergenceTest(self):
        """
        Return the stored ``(converged, args, kargs)`` tuple, or ``None``.
        """
        return self.get_attr('__converged__')

    def setConvergedReason(self, reason):
        """
        Set the termination reason (see ``Solver.Reason``).
        """
        cdef TaoSolverTerminationReason creason = reason
        CHKERR( TaoSetTerminationReason(self.slv, creason) )

    def getConvergedReason(self):
        """
        Return the termination reason (see ``Solver.Reason``).
        """
        cdef TaoSolverTerminationReason creason = TAO_CONTINUE_ITERATING
        CHKERR( TaoGetTerminationReason(self.slv, &creason) )
        return creason

    def setMonitor(self, monitor, args=None, kargs=None):
        """
        Append a Python monitor callback.

        The C-level monitor trampoline is installed only once, on the
        first registration; later calls append to the stored list.
        Passing ``None`` is a no-op.
        """
        if monitor is None: return
        # Normalize before branching so every stored entry has the same
        # (callable, tuple, dict) shape, not only the first one.
        if args  is None: args  = ()
        if kargs is None: kargs = {}
        cdef object monitorlist = self.get_attr('__monitor__')
        if monitorlist is None:
            CHKERR( TaoSetMonitor(self.slv, Tao_Monitor, NULL, NULL) )
            self.set_attr('__monitor__',  [(monitor, args, kargs)])
        else:
            monitorlist.append((monitor, args, kargs))

    def getMonitor(self):
        """
        Return the list of ``(monitor, args, kargs)`` tuples, or ``None``.
        """
        return self.get_attr('__monitor__')

    def cancelMonitor(self):
        """
        Remove all monitors, both in TAO and the stored Python list.
        """
        CHKERR( TaoCancelMonitors(self.slv) )
        self.set_attr('__monitor__',  None)

    #

    def solve(self, Vec x=None):
        """
        Run the solver, optionally setting ``x`` as the initial vector.
        """
        if x is not None:
            CHKERR( TaoSetInitialVector(self.slv, x.vec) )
        CHKERR( TaoSolve(self.slv) )

    def getSolution(self):
        """
        Return the solution vector (a new reference to TAO's vector).
        """
        cdef Vec vec = Vec()
        CHKERR( TaoGetSolutionVector(self.slv, &vec.vec) )
        PetscINCREF(vec.obj)
        return vec

    def getGradient(self):
        """
        Return the gradient vector (a new reference to TAO's vector).
        """
        cdef Vec vec = Vec()
        CHKERR( TaoGetGradientVector(self.slv, &vec.vec) )
        PetscINCREF(vec.obj)
        return vec

    def getVariableBounds(self):
        """
        Return ``(xl, xu)`` wrapping TAO's bound vectors.
        """
        cdef Vec xl = Vec(), xu = Vec()
        CHKERR( TaoGetVariableBounds(self.slv, &xl.vec, &xu.vec) )
        PetscINCREF(xl.obj); PetscINCREF(xu.obj)
        return (xl, xu)

    def getIterationNumber(self):
        """
        Return the current iteration number.
        """
        cdef PetscInt its=0
        CHKERR( TaoGetSolutionStatus(self.slv, &its, NULL, NULL, NULL, NULL, NULL) )
        return its

    def getObjectiveValue(self):
        """
        Return the current objective value.
        """
        cdef PetscReal fval=0
        CHKERR( TaoGetSolutionStatus(self.slv, NULL, &fval, NULL, NULL, NULL, NULL) )
        return fval

    getFunctionValue = getObjectiveValue

    def getGradientNorm(self):
        """
        Return the current gradient norm.
        """
        cdef PetscReal gnorm=0
        CHKERR( TaoGetSolutionStatus(self.slv, NULL, NULL, &gnorm, NULL, NULL, NULL) )
        return gnorm

    def getConstraintsNorm(self):
        """
        Return the current constraint-violation norm.
        """
        cdef PetscReal cnorm=0
        CHKERR( TaoGetSolutionStatus(self.slv, NULL, NULL, NULL, &cnorm, NULL, NULL) )
        return cnorm

    # singular-form alias (the name the ``cnorm`` property originally used)
    getConstraintNorm = getConstraintsNorm

    def getTerminationReason(self):
        """
        Return the termination reason (same as ``getConvergedReason``).
        """
        cdef TaoSolverTerminationReason reason = TAO_CONTINUE_ITERATING
        CHKERR( TaoGetTerminationReason(self.slv, &reason) )
        return reason

    def getSolutionStatus(self):
        """
        Return ``(its, fval, gnorm, cnorm, xdiff, reason)``.
        """
        cdef PetscInt its=0
        cdef PetscReal fval=0, gnorm=0, cnorm=0, xdiff=0
        cdef TaoSolverTerminationReason reason = TAO_CONTINUE_ITERATING
        CHKERR( TaoGetSolutionStatus(self.slv, &its,
                                     &fval, &gnorm, &cnorm, &xdiff,
                                     &reason) )
        return (its, fval, gnorm, cnorm, xdiff, reason)

    def getKSP(self):
        """
        Return the linear solver (KSP) used by this solver.
        """
        cdef PetscKSP ksp = NULL
        CHKERR( TaoGetKSP(self.slv, &ksp) )
        return new_KSP(ksp)

    # --- application context ---

    property appctx:
        def __get__(self):
            return self.getAppCtx()
        def __set__(self, value):
            self.setAppCtx(value)

    # --- linear solver ---

    property ksp:
        def __get__(self):
            return self.getKSP()

    # --- tolerances ---

    property ftol:
        def __get__(self):
            return self.getFunctionTolerances()
        def __set__(self, value):
            if isinstance(value, (tuple, list)):
                self.setFunctionTolerances(*value)
            elif isinstance(value, dict):
                self.setFunctionTolerances(**value)
            else:
                raise TypeError("expecting tuple/list or dict")

    property gtol:
        def __get__(self):
            return self.getGradientTolerances()
        def __set__(self, value):
            if isinstance(value, (tuple, list)):
                self.setGradientTolerances(*value)
            elif isinstance(value, dict):
                self.setGradientTolerances(**value)
            else:
                raise TypeError("expecting tuple/list or dict")

    property ctol:
        def __get__(self):
            return self.getConstraintTolerances()
        def __set__(self, value):
            if isinstance(value, (tuple, list)):
                self.setConstraintTolerances(*value)
            elif isinstance(value, dict):
                self.setConstraintTolerances(**value)
            else:
                raise TypeError("expecting tuple/list or dict")

    # --- iteration ---

    property its:
        def __get__(self):
            return self.getIterationNumber()

    property gnorm:
        def __get__(self):
            return self.getGradientNorm()

    property cnorm:
        def __get__(self):
            return self.getConstraintsNorm()

    property solution:
        def __get__(self):
            return self.getSolution()

    property objective:
        def __get__(self):
            return self.getObjectiveValue()

    property function:
        def __get__(self):
            return self.getFunctionValue()

    property gradient:
        def __get__(self):
            return self.getGradient()

    # --- convergence ---

    property reason:
        def __get__(self):
            return self.getConvergedReason()

    property iterating:
        def __get__(self):
            return self.reason == 0

    property converged:
        def __get__(self):
            return self.reason > 0

    property diverged:
        def __get__(self):
            return self.reason < 0

# --------------------------------------------------------------------

# The helper namespaces remain reachable as Solver.Type and Solver.Reason
# (bound as class attributes above); drop the module-level names.
del SolverType
del SolverConvergedReason

# --------------------------------------------------------------------
